#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################
# Some parts of the code are borrowed from: https://github.com/pytorch/vision
# with the following license:
#
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################
# For the MobileNetV2 model from https://github.com/ericsun99/MobileNet-V2-Pytorch
#
# BSD 2-Clause License
#
# Copyright (c) 2018, ericsun99
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#  this list of conditions and the following disclaimer in the documentation
#  and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#################################################################################


import torch
import math
from ... import xnn

###################################################
# Public names exported by `from <this module> import *`.
__all__ = ['get_config', 'MobileNetV2EricsunBase', 'MobileNetV2Ericsun', 'mobilenet_v2_ericsun']


###################################################
def get_config():
    """Return a fresh ``xnn.utils.ConfigNode`` with the default model settings.

    Fields:
        input_channels: channels of the network input (3).
        num_classes: classifier outputs; None builds a feature extractor only.
        width_mult: channel width multiplier applied to every layer.
        expand_ratio: expansion factor used in the default layer table.
        strides: five stage strides, read as s0..s4 by MobileNetV2EricsunBase.
        activation: activation class, instantiated as ``activation(inplace=True)``.
        kernel_size: kernel size passed to the residual blocks.
        dropout: whether a Dropout(p=0.2) precedes the final Linear layer.
        linear_dw: passed through to the residual blocks.
        layer_setting: custom ``[t, c, n, s]`` table, or None for the built-in one.
    """
    # Kept in the same order as the fields are expected to appear on the node.
    defaults = {
        'input_channels': 3,
        'num_classes': 1000,
        'width_mult': 1.,
        'expand_ratio': 6,
        'strides': (2, 2, 2, 2, 2),
        'activation': xnn.layers.DefaultAct2d,
        'kernel_size': 3,
        'dropout': False,
        'linear_dw': False,
        'layer_setting': None,
    }
    cfg = xnn.utils.ConfigNode()
    for field_name, field_value in defaults.items():
        setattr(cfg, field_name, field_value)
    return cfg


###################################################
def conv_bn(inp, oup, stride, activation, kernel_size=3):
    """Build ``Conv2d -> DefaultNorm2d -> activation`` as a ``torch.nn.Sequential``.

    The convolution has no bias (the norm layer follows it) and uses
    padding ``kernel_size // 2``.

    Args:
        inp: input channels.
        oup: output channels.
        stride: convolution stride.
        activation: activation class, instantiated as ``activation(inplace=True)``.
        kernel_size: convolution kernel size.
    """
    padding = kernel_size // 2
    layers = [
        torch.nn.Conv2d(inp, oup, kernel_size, stride, padding, bias=False),
        xnn.layers.DefaultNorm2d(oup),
        activation(inplace=True),
    ]
    return torch.nn.Sequential(*layers)

###################################################
def conv(inp, oup, stride, activation, kernel_size=3):
    """Build ``Conv2d -> activation`` (no normalization) as a ``torch.nn.Sequential``.

    The convolution has no bias and uses padding ``kernel_size // 2``.

    Args:
        inp: input channels.
        oup: output channels.
        stride: convolution stride.
        activation: activation class, instantiated as ``activation(inplace=True)``.
        kernel_size: convolution kernel size.
    """
    conv_layer = torch.nn.Conv2d(inp, oup, kernel_size=kernel_size, stride=stride,
                                 padding=kernel_size // 2, bias=False)
    return torch.nn.Sequential(conv_layer, activation(inplace=True))

def conv_1x1_bn(inp, oup, activation, groups=1):
    """Build a pointwise ``Conv2d(1x1) -> DefaultNorm2d -> activation`` block.

    Args:
        inp: input channels.
        oup: output channels.
        activation: activation class, instantiated as ``activation(inplace=True)``.
        groups: convolution groups (1 = dense pointwise conv).
    """
    pointwise = torch.nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0,
                                groups=groups, bias=False)
    norm = xnn.layers.DefaultNorm2d(oup)
    act = activation(inplace=True)
    return torch.nn.Sequential(pointwise, norm, act)


def width_multiplier(value, base=8, min_val=8):
    """Round a (scaled) channel count to a multiple of ``base``.

    ``value / base`` is rounded to the nearest integer (halves round up) and
    multiplied back by ``base``. The result is then clamped to at least
    ``min_val``. ``min_val=None`` uses ``base`` as the lower bound, and a falsy
    ``min_val`` (e.g. 0) disables the clamp.

    Returns:
        int: the adjusted channel count.
    """
    lower_bound = base if min_val is None else min_val
    rounded = int(math.floor(float(value) / base + 0.5) * base)
    if lower_bound:
        rounded = max(rounded, lower_bound)
    return int(rounded)

  
###################################################
class InvertedResidual(torch.nn.Module):
    """MobileNetV2 inverted-residual block: 1x1 expand -> depthwise -> 1x1 project.

    Args:
        input_channels: channels of the block input.
        output_channels: channels of the block output.
        stride: stride of the depthwise convolution.
        expand_ratio: hidden (expanded) width as a multiple of ``input_channels``.
        activation: activation class, instantiated as ``activation(inplace=True)``.
            ``None`` selects ``xnn.layers.DefaultAct2d`` (the same default that
            ``get_config`` uses).
        kernel_size: depthwise kernel size; padding is ``kernel_size // 2``.
        linear_dw: if True, the depthwise conv is left linear (its norm and
            activation slots hold ``BypassBlock``) and one activation is appended
            after the projection instead.

    A residual addition (``xnn.layers.AddBlock``) is used only when
    ``stride == 1`` and ``input_channels == output_channels``.
    """
    def __init__(self, input_channels, output_channels, stride, expand_ratio, activation=None, kernel_size=3, linear_dw=False):
        super(InvertedResidual, self).__init__()
        # The old default (None) always failed at `activation(inplace=True)`;
        # fall back to the package default activation instead.
        if activation is None:
            activation = xnn.layers.DefaultAct2d

        self.stride = stride
        self.use_res_connect = (self.stride == 1 and input_channels == output_channels)
        intermediate_channels = input_channels * expand_ratio

        conv = [
                # pw (expand)
                torch.nn.Conv2d(input_channels, intermediate_channels, 1, 1, 0, bias=False),
                xnn.layers.DefaultNorm2d(intermediate_channels),
                activation(inplace=True),
                # dw
                torch.nn.Conv2d(intermediate_channels, intermediate_channels, kernel_size, stride, kernel_size//2,
                                groups=intermediate_channels, bias=False),
                # BypassBlock placeholders keep the Sequential indices identical in both modes.
                xnn.layers.BypassBlock() if linear_dw else xnn.layers.DefaultNorm2d(intermediate_channels),
                xnn.layers.BypassBlock() if linear_dw else activation(inplace=True),
                # pw-linear (project)
                torch.nn.Conv2d(intermediate_channels, output_channels, 1, 1, 0, bias=False),
                xnn.layers.DefaultNorm2d(output_channels),
            ]
        if linear_dw:
            # With a linear depthwise stage, the activation is applied after the projection.
            conv.append(activation(inplace=True))

        self.conv = torch.nn.Sequential(*conv)
        if self.use_res_connect:
            self.add = xnn.layers.AddBlock(signed=True)

    def forward(self, x):
        """Apply the block; adds the input back when the residual path is enabled."""
        x1 = self.conv(x)
        if self.use_res_connect:
            x1 = self.add((x, x1))

        return x1


###################################################
class MobileNetV2EricsunBase(torch.nn.Module):
    """MobileNetV2 network assembled from a ``layer_setting`` table.

    Args:
        ResidualBlock: block class, called as ``ResidualBlock(in_channels,
            out_channels, stride, t, activation, kernel_size=..., linear_dw=...)``.
        model_config: config node with the fields produced by ``get_config()``.

    Each ``layer_setting`` row is ``[t, c, n, s]``:
        - ``t``: expand ratio.
        - ``c``: output channels, scaled by ``width_mult`` and rounded with
          ``width_multiplier``.
        - ``n``: number of repeats.
        - ``s``: stride of the first repeat.

    How the rows are used:
        - Row 0 builds the stem ``conv_bn``; only ``c`` and ``s`` are read.
        - Rows ``1..-2`` build the residual blocks.
        - The last row gives the channels of the final 1x1 conv; only ``c`` is
          read, and only when ``num_classes`` is not None.

    NOTE: when ``model_config.layer_setting`` is None, the default table is
    written back into ``model_config`` (the config object is mutated).

    When ``num_classes`` is None, the network stops after the residual stages and
    ``forward`` returns that feature map. Otherwise a 1x1 conv, global average
    pooling and an (optional dropout +) Linear classifier are appended.
    """
    def __init__(self, ResidualBlock, model_config):
        super().__init__()
        self.model_config = model_config
        self.num_classes = self.model_config.num_classes

        # strides of various layers
        s0 = model_config.strides[0]
        s1 = model_config.strides[1]
        s2 = model_config.strides[2]
        s3 = model_config.strides[3]
        s4 = model_config.strides[4]

        # setting of inverted residual blocks (default table; stored back into the config)
        if self.model_config.layer_setting is None:
            expand_ratio = self.model_config.expand_ratio
            self.model_config.layer_setting = [
                # t,            c,  n, s
                [1,             32, 1, s0],
                [1,             16, 1,  1],
                [expand_ratio,  24, 2, s1],
                [expand_ratio,  32, 3, s2],
                [expand_ratio,  64, 4, s3],
                [expand_ratio,  96, 3,  1],
                [expand_ratio, 160, 3, s4],
                [expand_ratio, 320, 1,  1],
                [1,           1280, 1,  1],
            ]

        # building first layer (stem): only channels and stride of row 0 are used
        stride = self.model_config.layer_setting[0][3]
        output_channels = width_multiplier(self.model_config.layer_setting[0][1]*self.model_config.width_mult)
        features = [conv_bn(self.model_config.input_channels, output_channels, stride, self.model_config.activation, kernel_size=3)]
        channels = output_channels

        # building inverted residual blocks; only the first repeat of a stage is strided
        for t, c, n, s in self.model_config.layer_setting[1:-1]:
            output_channels = width_multiplier(c*self.model_config.width_mult)
            for i in range(n):
                stride = (s if i == 0 else 1)
                block = ResidualBlock(channels, output_channels, stride, t, self.model_config.activation,
                                      kernel_size=self.model_config.kernel_size, linear_dw=self.model_config.linear_dw)
                features.append(block)
                channels = output_channels

        # building the head (final 1x1 conv + pooling) only for classification
        if self.model_config.num_classes is not None:
            output_channels = width_multiplier(self.model_config.layer_setting[-1][1]*self.model_config.width_mult)
            features.append(conv_1x1_bn(channels, output_channels, self.model_config.activation))
            features.append(torch.nn.AdaptiveAvgPool2d(1))
            channels = output_channels

            # building classifier
            self.classifier = torch.nn.Sequential(
                torch.nn.Dropout(p=0.2, inplace=True) if self.model_config.dropout else xnn.layers.BypassBlock(),
                torch.nn.Linear(channels, self.model_config.num_classes),
            )

        # make it torch.nn.Sequential
        self.features = torch.nn.Sequential(*features)

        self._initialize_weights()

    def forward(self, x):
        """Run the feature layers and, if present, the classifier.

        Returns the feature map when ``num_classes`` is None. Otherwise it
        returns the Linear classifier output computed on the flattened,
        pooled features.
        """
        for block in self.features:
            # TODO: Cleanup. It should not be done in this complicated way.
            # Print the feature size once, right before global pooling.
            if isinstance(block, torch.nn.AdaptiveAvgPool2d):
                xnn.utils.print_once('=> feature size is: ', x.size())
            x = block(x)

        if self.num_classes is not None:
            x = torch.flatten(x, 1)
            x = self.classifier(x)

        return x

    def _initialize_weights(self):
        """Initialize parameters of all submodules in place.

        - Conv2d: Kaiming-normal (fan_out, relu) weights, zero bias.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: N(0, 0.01) weights, zero bias.

        Norm layers that are not ``torch.nn.BatchNorm2d`` instances are left untouched.
        """
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, torch.nn.BatchNorm2d):
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                # ResidualBlock is injected and may contain bias-free Linear layers.
                if m.bias is not None:
                    m.bias.data.zero_()


class MobileNetV2Ericsun(MobileNetV2EricsunBase):
    """MobileNetV2 built from ``InvertedResidual`` blocks.

    The user-supplied ``model_config`` is merged over the defaults from
    ``get_config()`` before the network is constructed.
    """
    def __init__(self, model_config):
        merged_config = get_config().merge_from(model_config)
        super().__init__(InvertedResidual, merged_config)


###################################################
class mobilenet_v2_ericsun(MobileNetV2Ericsun):
    """Constructor-style entry point: build the model and optionally load weights.

    Args:
        model_config: overrides merged over ``get_config()``.
        pretrained: if truthy, passed to ``xnn.utils.load_weights(self, pretrained)``.
            Presumably a checkpoint path or URL; TODO confirm against load_weights.
    """
    def __init__(self, model_config, pretrained):
        super().__init__(model_config)
        if not pretrained:
            return
        xnn.utils.load_weights(self, pretrained)

